
import re
import numpy as np
import os
import json
from openai import OpenAI

class FalsePresuppositionScorer:
    """Attach per-entry hallucination/utility scores to a JSON file of responses.

    The loaded JSON is iterated as a list of dicts, each mapping a prompt to an
    entry dict. Entries may carry ``'atomic_units'`` and
    ``'hallucinated_atomic_units'`` lists (missing keys default to empty lists).
    """

    def __init__(self, response_filepath, client):
        """Load responses from *response_filepath* and keep *client* for judging.

        Args:
            response_filepath: Path to the JSON file of responses.
            client: Client exposing ``chat.completions.create`` (used only by
                ``_is_explicit_no_response``).
        """
        self.responses = self.read_json_file(response_filepath)
        self.client = client

    def read_json_file(self, filepath):
        """Return the parsed JSON content of *filepath*, decoded as UTF-8."""
        with open(filepath, 'r', encoding='utf-8') as f:
            return json.load(f)

    def compute_scores(self):
        """Add a ``'scores'`` dict to every entry in place.

        Returns:
            ``self.responses`` (the same, now-annotated, object).
        """
        for data in self.responses:
            for prompt, entry in data.items():
                self._add_entry_scores(prompt, entry)
        return self.responses

    def _add_entry_scores(self, prompt, entry):
        """Compute and store ``entry['scores']`` for one prompt/entry pair.

        An entry with no atomic units counts as an abstention: hallucination
        score 0.0, response ratio 0.0, utility 1.0. Otherwise the response
        ratio is 1.0, utility is 0.0, and the hallucination score is the
        fraction of atomic units listed as hallucinated.

        NOTE: abstention is inferred from empty ``atomic_units`` only; the
        LLM-based ``_is_explicit_no_response`` check is not called here, so
        *prompt* is currently unused.
        """
        atomic_units = entry.get('atomic_units', [])
        hallucinated_units = entry.get('hallucinated_atomic_units', [])
        abstained = not atomic_units

        entry['scores'] = {
            'hallucination_score': self._calc_hallucination(
                hallucinated_units, atomic_units
            ),
            'response_ratio': 0.0 if abstained else 1.0,
            'utility_score': 1.0 if abstained else 0.0,
        }

    def _is_explicit_no_response(self, prompt, response):
        """Ask the LLM judge whether *response* abstains from answering *prompt*.

        Returns:
            True only when the judge answers 'yes'; 'no', 'partially' and any
            unparseable reply yield False.
        """
        judge_prompt = f"""
            You are given a prompt and its response . Your task is to judge whether the response to the prompt is an abstention from answering or not? Just answer with 'yes' or 'no' or 'partially'. 'yes' if it is an abstention, 'no' if it is not an abstention and it seems like an answer, and 'partially' IFF there are multiple items and some of them are abstentions.
            prompt: {prompt}
            response: {response}

            Return your generation in the following format. Do not include any other text:

            abstention: [your judgement here]

            """

        completion = self.client.chat.completions.create(
                model="gpt-4.1-mini",
                messages=[
                    {"role": "system", "content": "You are a helpful assistant that judges a pair of prompts and responses and checks whether the response is an absention from answering the prompt. Always respond with just yes, no, or partially, nothing else."},
                    {"role": "user", "content": judge_prompt},
                ],
            )
        return self._parse_abstention(completion.choices[0].message.content)

    @staticmethod
    def _parse_abstention(raw):
        """Map the judge's raw reply to a boolean abstention verdict.

        The user prompt asks for ``abstention: <label>`` while the system
        prompt asks for a bare label, so both forms are accepted,
        case-insensitively. Surrounding whitespace, periods, quotes and square
        brackets around the label are ignored. A ``None`` reply is treated as
        empty.
        """
        text = (raw or "").strip().lower()
        _, sep, label = text.partition("abstention:")
        label = (label if sep else text).strip(" \t\r\n.'\"[]")

        if label == "yes":
            return True
        if label in ("no", "partially"):
            return False
        print(f"Unexpected response from abstention check: {label}")
        return False

    def _calc_hallucination(self, hallucinated, atomic_units):
        """Return len(hallucinated) / len(atomic_units) rounded to 4 places.

        Returns 0.0 when *atomic_units* is empty, to avoid division by zero.
        """
        if not atomic_units:
            return 0.0
        return round(len(hallucinated) / len(atomic_units), 4)

if __name__ == "__main__":
    import argparse
    import yaml

    parser = argparse.ArgumentParser(
        description="Process files in a folder using various evaluators.")

    parser.add_argument(
        "--input_dir", help="Path to the folder containing the JSON response files to score.", default="./")

    parser.add_argument(
        "--output_dir", help="Path to the folder where output files will be saved.", default="./res")

    def read_api_keys(config_file="config.yml"):
        """Return the OpenAI API key stored under 'openai_api_key' in *config_file*.

        Only this key is required; other keys in the config are ignored.
        """
        with open(config_file, 'r', encoding='utf-8') as file:
            config = yaml.safe_load(file)
        return config['openai_api_key']

    # Parse arguments before reading config.yml so `--help` works without it.
    args = parser.parse_args()
    input_dir = args.input_dir
    output_dir = args.output_dir

    client = OpenAI(api_key=read_api_keys())

    # open(..., 'w') below fails if the destination folder does not exist.
    os.makedirs(output_dir, exist_ok=True)

    for filename in sorted(os.listdir(input_dir)):
        file_path = os.path.join(input_dir, filename)
        # Only score JSON files: the default input_dir "./" also contains this
        # script, config.yml and possibly the output directory itself.
        if not filename.lower().endswith('.json') or not os.path.isfile(file_path):
            continue

        scorer = FalsePresuppositionScorer(file_path, client)
        scored_data = scorer.compute_scores()

        per_prompt_path = os.path.join(output_dir, filename)
        with open(per_prompt_path, 'w', encoding='utf-8') as f:
            json.dump(scored_data, f, indent=4)

        print(f"Successfully added per-prompt scores to {per_prompt_path}")